import os
import argparse
import torch


from PIL import Image
import time
import joblib

import numpy as np

from torchvision import transforms as transforms

# Command-line interface.  Flag names, types and defaults are unchanged; only
# the description/help texts were corrected to describe what the script
# actually does with each value.
parser = argparse.ArgumentParser(
    description='Run a pretrained timm classifier over a list of images '
                'and save the top-1 class indices',
)
parser.add_argument(
    '--input_txt', default='', type=str,
    help='Text file listing one image path per line',
)
parser.add_argument(
    '--save_folder', default='./testout/', type=str,
    help='Dir to save the .pkl results',
)
parser.add_argument(
    '--cpu', action="store_true", default=False,
    help='Use cpu inference',
)
parser.add_argument(
    '--split', default=0, type=int,
    help='Integer tag used as the prefix of the output file names',
)
args = parser.parse_args()


class dataset(torch.utils.data.Dataset):
    """Image dataset backed by a text file listing one image path per line.

    Each item is a ``(transformed_image, path)`` pair, so that predictions
    can be matched back to the file they were computed from.
    """

    def __init__(self, list, transforms):
        """Read the image paths from the listing file.

        Args:
            list: Path to a text file with one image path per line.  The
                name shadows the builtin but is kept so the constructor's
                interface stays unchanged.
            transforms: Callable applied to each RGB image in
                ``__getitem__``.
        """
        # A context manager closes the listing file promptly.  Blank lines
        # (e.g. an extra empty line at the end of the file) are skipped:
        # they would otherwise become '' entries that fail in Image.open.
        with open(list, 'r') as listing:
            stripped = (line.strip() for line in listing)
            self.image_list = [path for path in stripped if path]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(transforms(image), path)`` for the image at ``index``."""
        _path = self.image_list[index]
        # convert() returns a new, fully loaded image, so the file opened by
        # Image.open can be closed as soon as the conversion is done.
        with Image.open(_path) as img:
            data = img.convert('RGB')
        return self.transforms(data), _path

    def __len__(self):
        """Return the number of image paths read from the listing file."""
        return len(self.image_list)



if __name__ == '__main__':
    # Inference only: turn autograd off for the whole run.
    torch.set_grad_enabled(False)

    # Imported here so the module itself can be imported without timm.
    import timm

    BS = 8

    # Create the output directory before the inference loop so that a
    # missing folder cannot make the final dump fail after all the work.
    if args.save_folder:
        os.makedirs(args.save_folder, exist_ok=True)

    # net and model
    model = timm.create_model('resnetv2_50x3_bitm_in21k', pretrained=True)
    model.eval()
    print('Finished loading model!')
    print(model)
    device = torch.device("cpu" if args.cpu else "cuda")
    model = model.to(device)

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        # mean=std=0.5 rescales each channel as (x - 0.5) / 0.5.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    test_loader = torch.utils.data.DataLoader(
        dataset(args.input_txt, preprocess),
        batch_size=BS, shuffle=False, num_workers=4,
        # Pinned host memory is only useful when copying to a GPU.
        pin_memory=not args.cpu,
    )
    num_images = len(test_loader.dataset)

    # testing begin
    amaxes = []  # top-1 class index per image, in loader order
    paths = []   # image path per prediction, aligned with `amaxes`
    for i, (imgs, image_paths) in enumerate(test_loader):
        # Timer starts after the batch has been fetched, so the reported
        # time excludes data loading.
        now = time.time()
        imgs = imgs.to(device)
        out = model(imgs)
        amax = torch.argmax(out, dim=-1).cpu().numpy()
        amaxes.extend(amax)
        paths.extend(image_paths)
        if i % 10 == 0:
            # Sample the clock once so the time and the rate agree, and use
            # the real batch size (the last batch may be shorter than BS).
            elapsed = max(time.time() - now, 1e-9)
            print(f"im_detect: {i*BS + 1:5}/{num_images} Time: {elapsed:.3f}s",
                  f"== {imgs.size(0)/elapsed:.1f}Hz", flush=True)
            # Running sanity check on the predictions collected so far.
            print(len(np.unique(amaxes)), np.max(amaxes), np.mean(amaxes))
            print(image_paths[-5:])

    # Predictions and their paths go to two files sharing the split prefix.
    joblib.dump(amaxes, os.path.join(args.save_folder, str(args.split) + '-21k_dets.pkl'))
    joblib.dump(paths, os.path.join(args.save_folder, str(args.split) + '-21k_paths.pkl'))

